#include <VM/vm.h>
#include <VM/object.h>
#include <Common/util.h>

#include <chrono>
#include <cstddef>
#include <utility>

// Builds a VM with empty stacks and registers the built-in native functions
// `clock` and `print` as globals.
VM::VM()
{
	resetVM();

	// clock(): milliseconds elapsed since the system_clock epoch, as a double.
	auto clockNative = [](std::vector<Value>)
	{
		using namespace std::chrono;
		const auto sinceEpoch = system_clock::now().time_since_epoch();
		return static_cast<double>(duration_cast<milliseconds>(sinceEpoch).count());
	};

	// print(...): writes each argument followed by a space, then a newline,
	// and evaluates to nil. Registered with arity -1 (any argument count).
	auto printNative = [](std::vector<Value> args)
	{
		for (const Value &arg : args)
			std::cout << arg << ' ';
		std::cout << "\n";
		return Value{};
	};

	globals.emplace("clock", std::make_shared<NativeFunction>("clock", 0, clockNative));
	globals.emplace("print", std::make_shared<NativeFunction>("print", -1, printNative));
}

// Compiles `source` and, if compilation succeeds, runs it to completion.
//
// Returns COMPILE_ERROR when the compiler reported any diagnostic,
// RUNTIME_ERROR when execution threw an Error (reported through `reporter`),
// and OK otherwise. After a runtime error the VM is reset, so the next call
// starts from empty stacks. Globals are kept.
InterpretResult VM::interpret(std::string_view source)
{
	reporter.reset();

	FunctionPtr entry_function = compiler.compile(source, 1);
	if (reporter.count() > 0)
	{
		compiler.reset();
		return InterpretResult::COMPILE_ERROR;
	}

#ifdef _DEBUG
	entry_function->chunk->disassemble("code");
#endif

	// The script function occupies the bottom slot of its own call frame.
	valueStack.push_back(entry_function);

	try
	{
		// call() can raise a runtime error too (arity check), so it is
		// guarded together with execute().
		call(entry_function.get(), 0);
		execute();
	}
	catch (const Error &e)
	{
		reporter.report(e, true);
		// Clear the call frames as well as the value stack. Stale frames hold
		// raw FunctionObject pointers that may dangle once the compiled script
		// is released. The next runtimeError() would also walk them when
		// printing its traceback.
		resetVM();
		return InterpretResult::RUNTIME_ERROR;
	}

	return InterpretResult::OK;
}

void VM::execute()
{
	CallFrame *frame = nullptr;

	auto updateFrame = [&]()
	{
		if (&frames.back() != frame)
		{
			frame = &frames.back();
			chunk = frame->function->chunk.get();
			return true;
		}
		return false;
	};

	updateFrame();

	for (frame->ip; frame->ip < chunk->size(); frame->ip++)
	{

#ifdef _DEBUG
		print("\n\n== stack info ==\n");
		for (auto &val : valueStack)
		{
			print("[ %o ]", val);
		}
		print("\n");
		// debug模式下，执行前先输出字节码
		chunk->disassembleInstruction(frame->ip);
		print("\n");
#endif

		OpCode instruction = static_cast<OpCode>(chunk->getCode(frame->ip));

		switch (instruction)
		{
		case OpCode::CONSTANT:
		{
			u8_t index = chunk->getCode(++frame->ip);
			valueStack.push_back(chunk->getConstant(index));
			break;
		}

		case OpCode::NIL:
		{
			valueStack.emplace_back();
			break;
		}

		case OpCode::TRUE:
		{
			valueStack.emplace_back(true);
			break;
		}

		case OpCode::FALSE:
		{
			valueStack.emplace_back(false);
			break;
		}

		case OpCode::ADD:
		{
			// FIXME 从性能角度，pop出一个，另一个通过back()修改
			// 可以避免容器大小变化导致的性能损失
			try
			{
				valueStack.push_back(pop() + pop());
				break;
			}
			catch (const std::runtime_error &e)
			{
				return runtimeError(e.what());
			}
		}

		case OpCode::SUBTRACT:
		{
			try
			{
				valueStack.push_back(pop() - pop());
				break;
			}
			catch (const std::runtime_error &e)
			{
				return runtimeError(e.what());
			}
		}

		case OpCode::MULTIPLY:
		{
			try
			{
				valueStack.push_back(pop() * pop());
				break;
			}
			catch (const std::runtime_error &e)
			{
				return runtimeError(e.what());
			}
		}

		case OpCode::DIVIDE:
		{
			try
			{
				valueStack.push_back(pop() / pop());
				break;
			}
			catch (const std::runtime_error &e)
			{
				return runtimeError(e.what());
			}
		}

		case OpCode::NEGATE:
		{
			valueStack.back() = -peekNumberOperand(frame->ip);
			break;
		}

		case OpCode::NOT:
		{
			valueStack.back() = !isTrue(valueStack.back());
			break;
		}

		case OpCode::EQEQ:
		{
			valueStack.push_back(pop() == pop());
			break;
		}

		case OpCode::NEQ:
		{
			valueStack.push_back(pop() != pop());
			break;
		}

		case OpCode::GT:
		{
			try
			{
				valueStack.push_back(pop() > pop());
				break;
			}
			catch (const std::runtime_error &e)
			{
				return runtimeError(e.what());
			}
			break;
		}

		case OpCode::GTE:
		{
			try
			{
				valueStack.push_back(pop() >= pop());
				break;
			}
			catch (const std::runtime_error &e)
			{
				return runtimeError(e.what());
			}
		}

		case OpCode::LT:
		{
			try
			{
				valueStack.push_back(pop() < pop());
				break;
			}
			catch (const std::runtime_error &e)
			{
				return runtimeError(e.what());
			}
		}

		case OpCode::LTE:
		{
			try
			{
				valueStack.push_back(pop() <= pop());
				break;
			}
			catch (const std::runtime_error &e)
			{
				return runtimeError(e.what());
			}
			break;
		}

		case OpCode::POP:
		{
			valueStack.pop_back();
			break;
		}

		case OpCode::DEFINE_GLOBAL:
		{
			Value val = pop();
			std::string varName = std::get<std::string>(pop());
			globals[varName] = val;
			break;
		}

		case OpCode::GET_GLOBAL:
		{
			// this is unnecessary, in fact, a waste
			// valueStack.back() is the variable
			u8_t index = chunk->getCode(++frame->ip);
			std::string name = std::get<std::string>(chunk->getConstant(index));
			if (auto it = globals.find(name); it == globals.end())
			{
				throw Error(chunk->getLine(frame->ip), "Undefind variable " + name);
			}
			else
			{
				valueStack.back() = it->second;
			}
			break;
		}

		case OpCode::SET_GLOBAL:
		{
			Value val = pop();
			u8_t index = chunk->getCode(++frame->ip);
			std::string name = std::get<std::string>(chunk->getConstant(index));
			if (auto it = globals.find(name); it == globals.end())
			{
				throw Error(chunk->getLine(frame->ip), "Identifier " + name + " is undefined");
			}
			else
			{
				valueStack.back() = it->second = val;
			}
			break;
		}

		case OpCode::GET_LOCAL:
		{
			u8_t index = chunk->getCode(++frame->ip);					   // 相对偏移
			valueStack.push_back(valueStack[frame->stack_offset + index]); // 绝对偏移
			break;
		}

		case OpCode::SET_LOCAL:
		{
			u8_t index = chunk->getCode(++frame->ip);
			valueStack[frame->stack_offset + index] = valueStack.back();
			break;
		}

		case OpCode::JUMP:
		{
			u16_t distance = static_cast<u16_t>((chunk->getCode(++frame->ip) << 8) | chunk->getCode(++frame->ip));
			frame->ip += distance;
			break;
		}

		case OpCode::JUMP_IF_TRUE:
		{
			u16_t distance = static_cast<u16_t>((chunk->getCode(++frame->ip) << 8) | chunk->getCode(++frame->ip));
			if (isTrue(valueStack.back()))
				frame->ip += distance;
			break;
		}

		case OpCode::JUMP_IF_FALSE:
		{
			u16_t distance = static_cast<u16_t>((chunk->getCode(++frame->ip) << 8) | chunk->getCode(++frame->ip));
			if (!isTrue(valueStack.back()))
				frame->ip += distance;
			break;
		}

		case OpCode::LOOP:
		{
			u16_t distance = static_cast<u16_t>((chunk->getCode(++frame->ip) << 8) | chunk->getCode(++frame->ip));
			frame->ip -= distance;
			break;
		}

		case OpCode::CALL:
		{
			u8_t arg_nums = chunk->getCode(++frame->ip);

			// FIXME: NativeFunction返回来frame->ip不应--
			callValue(*(valueStack.rbegin() + arg_nums), arg_nums);

			// 如果call成功调用，则会进入新的CallFrame
			if (updateFrame())
				frame->ip--;

			break;
		}

		case OpCode::RETURN:
		{
			Value return_val = pop(); // 函数返回值

			// discard arguments
			while (valueStack.size() > frame->stack_offset)
				pop();

			frames.pop_back();
			if (frames.size() == 0)
			{
				return;
			}

			// 函数返回值入栈
			valueStack.push_back(return_val);

			// 回到上一层调用
			updateFrame();
			break;
		}

		default:
			break;
		}
	}
}

// Returns the VM to an idle state: no active chunk, no call frames and no
// values on the stack. Capacity for 64 frames and 256 values is reserved up
// front. Globals are not touched.
void VM::resetVM()
{
	chunk = nullptr;

	frames.clear();
	frames.reserve(64);

	valueStack.clear();
	valueStack.reserve(256);
}

void VM::runtimeError(std::string error_message)
{
	for (auto it = frames.rbegin(); it != frames.rend(); it++)
	{
		FunctionObject *function = it->function;
		print("[line %d] in %s\n", function->chunk->getLine(it->ip), function->name);
	}

	throw Error(chunk->getLine(frames.back().ip), error_message);
}

// Removes the top of the value stack and returns it.
// The element is moved out rather than copied because it is destroyed
// immediately afterwards. This avoids, e.g., a string allocation or a
// shared_ptr refcount round-trip on every pop.
// Precondition: valueStack is non-empty (not checked here).
Value VM::pop()
{
	Value top = std::move(valueStack.back());
	valueStack.pop_back();
	return top;
}

// Invokes `callee` with the `argCount` arguments currently on top of the value
// stack. The callee value itself is expected directly beneath them.
//
// - Script function: delegates to call(), which pushes a new CallFrame; its
//   body is then run by execute().
// - Native function: runs immediately on a copy of the arguments, then
//   replaces callee + arguments on the stack with the returned value.
// Anything else raises the runtime error "Can only call functions".
void VM::callValue(Value callee, int argCount)
{
	if (!isCallable(callee))
		runtimeError("Can only call functions");

	switch (callee.index())
	{
	case FUNCTION:
		return call(std::get<FunctionPtr>(callee).get(), argCount);

	case NATIVE_FUNCTION:
	{
		NativeFuncPtr &func = std::get<NativeFuncPtr>(callee);
		// An arity of -1 marks a native that accepts any number of arguments.
		if (func->arity != -1 && argCount != func->arity)
		{
			runtimeError(format("Expected %d arguments, instead got %d", func->arity, argCount));
		}
		Value result = func->function(std::vector<Value>(valueStack.end() - argCount, valueStack.end()));

		// Drop the arguments and the callee itself in one step, instead of
		// popping (and copying) them one by one.
		valueStack.erase(valueStack.end() - argCount - 1, valueStack.end());

		valueStack.push_back(std::move(result));
		return;
	}

	default:
		// isCallable() accepted a value this switch cannot dispatch. Fail
		// loudly instead of silently leaving callee and arguments on the stack
		// without a result.
		runtimeError("Can only call functions");
		break;
	}
}

// Enters script function `function` by pushing a new CallFrame with ip 0.
// The argument count must match the declared arity exactly. Unlike native
// functions, -1 is not treated as "any count" here.
void VM::call(FunctionObject *function, int argCount)
{
	if (function->arity != argCount)
		runtimeError(format("Expected %d arguments, instead got %d", function->arity, argCount));

	// The frame's stack window starts at the callee value, which sits directly
	// below its arguments.
	const auto calleeSlot = valueStack.size() - argCount - 1;
	frames.emplace_back(function, 0, calleeSlot);
}
